#ifndef L2_H_
#define L2_H_
 
#include <arm_neon.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <iterator>
#include <limits>
#include <type_traits>
#include <utility>
#include <vector>
 
// Gives a symbol default (exported) visibility via the GCC/Clang
// `visibility` attribute. Guarded so that an earlier definition of the same
// macro (e.g. from another header in the build) is not redefined.
#ifndef SCANN_API_PUBLIC
#define SCANN_API_PUBLIC __attribute__((visibility("default")))
#endif
 
namespace hw_alg {
 
/// Squared Euclidean distance between two byte vectors of length `d`.
///
/// Returns sum_k (x[k] - y[k])^2 accumulated in uint32_t. The result is exact
/// for d <= 66051 (66051 * 255^2 < 2^32); beyond that it wraps modulo 2^32.
///
/// Each byte's absolute difference |x - y| is formed with vabdq_u8 and dotted
/// with itself via vdotq_u32 (Armv8.2 dot-product extension). That yields the
/// square without widening to 16 bits first.
///
/// @param x  first vector, `d` bytes.
/// @param y  second vector, `d` bytes.
/// @param d  number of elements.
/// @return   the squared L2 distance as uint32_t.
SCANN_API_PUBLIC inline __attribute__((always_inline)) uint32_t L2sqr_u8(const uint8_t* x, const uint8_t* __restrict y, const size_t d) {
    // Four independent accumulators so consecutive dot products do not all
    // depend on the same register.
    uint32x4_t acc_a = vdupq_n_u32(0);
    uint32x4_t acc_b = vdupq_n_u32(0);
    uint32x4_t acc_c = vdupq_n_u32(0);
    uint32x4_t acc_d = vdupq_n_u32(0);

    size_t pos = 0;

    // Main body: four 16-byte vectors (64 bytes) per iteration.
    while (pos + 64 <= d) {
        const uint8_t* xp = x + pos;
        const uint8_t* yp = y + pos;

        const uint8x16_t abs_a = vabdq_u8(vld1q_u8(xp), vld1q_u8(yp));
        const uint8x16_t abs_b = vabdq_u8(vld1q_u8(xp + 16), vld1q_u8(yp + 16));
        const uint8x16_t abs_c = vabdq_u8(vld1q_u8(xp + 32), vld1q_u8(yp + 32));
        const uint8x16_t abs_d = vabdq_u8(vld1q_u8(xp + 48), vld1q_u8(yp + 48));

        acc_a = vdotq_u32(acc_a, abs_a, abs_a);
        acc_b = vdotq_u32(acc_b, abs_b, abs_b);
        acc_c = vdotq_u32(acc_c, abs_c, abs_c);
        acc_d = vdotq_u32(acc_d, abs_d, abs_d);

        pos += 64;
    }

    // Any remaining whole 16-byte vectors.
    while (pos + 16 <= d) {
        const uint8x16_t abs_v = vabdq_u8(vld1q_u8(x + pos), vld1q_u8(y + pos));
        acc_a = vdotq_u32(acc_a, abs_v, abs_v);
        pos += 16;
    }

    // Fold the four accumulators, then sum across lanes.
    const uint32x4_t folded = vaddq_u32(vaddq_u32(acc_a, acc_b), vaddq_u32(acc_c, acc_d));
    uint32_t total = vaddvq_u32(folded);

    // Scalar tail of fewer than 16 elements.
    for (; pos < d; ++pos) {
        const int32_t delta = static_cast<int32_t>(x[pos]) - static_cast<int32_t>(y[pos]);
        total += static_cast<uint32_t>(delta * delta);
    }
    return total;
}
 
/// Squared L2 distances from one query vector to `num_rows` contiguous rows.
///
/// Row r starts at `y + r * d`. `dis[r]` receives sum_k (x[k] - y[r*d + k])^2,
/// converted to float.
///
/// Rows are processed four at a time with NEON. Leftover rows (num_rows % 4)
/// go through L2sqr_u8. Both paths accumulate in uint32_t and convert to
/// float exactly once, so a row gets the same value whichever path handles it.
///
/// Requires the Armv8.2 dot-product extension (vdotq_u32). The uint32 sums are
/// exact for d <= 66051 (66051 * 255^2 < 2^32); beyond that they wrap
/// modulo 2^32.
///
/// @param x         query vector, `d` bytes.
/// @param y         `num_rows * d` bytes of rows, one after another.
/// @param num_rows  number of rows in `y`.
/// @param d         number of elements per vector.
/// @param dis       output array of `num_rows` floats.
SCANN_API_PUBLIC inline __attribute__((always_inline)) void L2sqr_batch4_u8InOne(
        const uint8_t*  x,
        const uint8_t* __restrict y,
        const size_t num_rows,
        const size_t d,
        float* dis) {
    const size_t single_round = 16;  // bytes per NEON vector
    // size_t, not int: nid is compared with and multiplied by size_t values.
    size_t nid = 0;
    for (; nid + 4 <= num_rows; nid += 4) {
        // Rows are only read, so keep the pointers const.
        const uint8_t* y0 = y + nid * d;
        const uint8_t* y1 = y0 + d;
        const uint8_t* y2 = y1 + d;
        const uint8_t* y3 = y2 + d;
        uint32x4_t neon_res1 = vdupq_n_u32(0);
        uint32x4_t neon_res2 = vdupq_n_u32(0);
        uint32x4_t neon_res3 = vdupq_n_u32(0);
        uint32x4_t neon_res4 = vdupq_n_u32(0);
        size_t i = 0;  // first element not yet covered by the NEON part

        if (d >= 2 * single_round) {
            // Software-pipelined: the loads for the next 16-byte chunk are
            // issued before the dot products of the current chunk.
            uint8x16_t neon_query = vld1q_u8(x);
            uint8x16_t neon_base1 = vld1q_u8(y0);
            uint8x16_t neon_base2 = vld1q_u8(y1);
            uint8x16_t neon_base3 = vld1q_u8(y2);
            uint8x16_t neon_base4 = vld1q_u8(y3);

            // |a - b| per byte; dotting it with itself gives (a - b)^2.
            uint8x16_t neon_diff1 = vabdq_u8(neon_base1, neon_query);
            uint8x16_t neon_diff2 = vabdq_u8(neon_base2, neon_query);
            uint8x16_t neon_diff3 = vabdq_u8(neon_base3, neon_query);
            uint8x16_t neon_diff4 = vabdq_u8(neon_base4, neon_query);

            neon_query = vld1q_u8(x + single_round);
            neon_base1 = vld1q_u8(y0 + single_round);
            neon_base2 = vld1q_u8(y1 + single_round);
            neon_base3 = vld1q_u8(y2 + single_round);
            neon_base4 = vld1q_u8(y3 + single_round);

            neon_res1 = vdotq_u32(neon_res1, neon_diff1, neon_diff1);
            neon_res2 = vdotq_u32(neon_res2, neon_diff2, neon_diff2);
            neon_res3 = vdotq_u32(neon_res3, neon_diff3, neon_diff3);
            neon_res4 = vdotq_u32(neon_res4, neon_diff4, neon_diff4);

            // At each iteration's start, chunk [i - 16, i) is loaded but not yet
            // accumulated. `d - single_round` cannot underflow because d >= 32.
            for (i = 2 * single_round; i <= d - single_round; i += single_round) {
                neon_diff1 = vabdq_u8(neon_base1, neon_query);
                neon_diff2 = vabdq_u8(neon_base2, neon_query);
                neon_diff3 = vabdq_u8(neon_base3, neon_query);
                neon_diff4 = vabdq_u8(neon_base4, neon_query);

                neon_query = vld1q_u8(x + i);
                neon_base1 = vld1q_u8(y0 + i);
                neon_base2 = vld1q_u8(y1 + i);
                neon_base3 = vld1q_u8(y2 + i);
                neon_base4 = vld1q_u8(y3 + i);

                neon_res1 = vdotq_u32(neon_res1, neon_diff1, neon_diff1);
                neon_res2 = vdotq_u32(neon_res2, neon_diff2, neon_diff2);
                neon_res3 = vdotq_u32(neon_res3, neon_diff3, neon_diff3);
                neon_res4 = vdotq_u32(neon_res4, neon_diff4, neon_diff4);
            }
            // Drain: accumulate the last loaded chunk [i - 16, i).
            neon_diff1 = vabdq_u8(neon_base1, neon_query);
            neon_diff2 = vabdq_u8(neon_base2, neon_query);
            neon_diff3 = vabdq_u8(neon_base3, neon_query);
            neon_diff4 = vabdq_u8(neon_base4, neon_query);

            neon_res1 = vdotq_u32(neon_res1, neon_diff1, neon_diff1);
            neon_res2 = vdotq_u32(neon_res2, neon_diff2, neon_diff2);
            neon_res3 = vdotq_u32(neon_res3, neon_diff3, neon_diff3);
            neon_res4 = vdotq_u32(neon_res4, neon_diff4, neon_diff4);
        } else if (d >= single_round) {
            // Exactly one full vector fits.
            const uint8x16_t neon_query = vld1q_u8(x);
            const uint8x16_t neon_diff1 = vabdq_u8(vld1q_u8(y0), neon_query);
            const uint8x16_t neon_diff2 = vabdq_u8(vld1q_u8(y1), neon_query);
            const uint8x16_t neon_diff3 = vabdq_u8(vld1q_u8(y2), neon_query);
            const uint8x16_t neon_diff4 = vabdq_u8(vld1q_u8(y3), neon_query);

            neon_res1 = vdotq_u32(neon_res1, neon_diff1, neon_diff1);
            neon_res2 = vdotq_u32(neon_res2, neon_diff2, neon_diff2);
            neon_res3 = vdotq_u32(neon_res3, neon_diff3, neon_diff3);
            neon_res4 = vdotq_u32(neon_res4, neon_diff4, neon_diff4);
            i = single_round;
        }

        // Horizontal reduction: after two pairwise-add levels,
        // lane k of `sums` holds the total for row k.
        uint32x4_t sums = vpaddq_u32(vpaddq_u32(neon_res1, neon_res2),
                                     vpaddq_u32(neon_res3, neon_res4));

        if (i < d) {
            // Scalar tail (fewer than 16 elements). It stays in integers so the
            // row total is rounded to float only once, matching L2sqr_u8.
            uint32_t tail[4] = {0, 0, 0, 0};
            for (; i < d; ++i) {
                const int32_t q = x[i];
                const int32_t e0 = q - y0[i];
                const int32_t e1 = q - y1[i];
                const int32_t e2 = q - y2[i];
                const int32_t e3 = q - y3[i];
                tail[0] += static_cast<uint32_t>(e0 * e0);
                tail[1] += static_cast<uint32_t>(e1 * e1);
                tail[2] += static_cast<uint32_t>(e2 * e2);
                tail[3] += static_cast<uint32_t>(e3 * e3);
            }
            sums = vaddq_u32(sums, vld1q_u32(tail));
        }
        vst1q_f32(dis + nid, vcvtq_f32_u32(sums));
    }
    // Leftover rows (num_rows % 4) use the single-row kernel.
    for (; nid < num_rows; ++nid) {
        dis[nid] = static_cast<float>(L2sqr_u8(x, y + nid * d, d));
    }
}
} // namespace hw_alg
 
#endif